{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "# Sequence classification model for IMDB Sentiment Analysis\n",
    "(c) Deniz Yuret, 2019\n",
    "* Objectives: Learn the structure of the IMDB dataset and train a simple RNN model.\n",
    "* Prerequisites: [RNN models](60.rnn.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set display width, load packages, import symbols\n",
    "ENV[\"COLUMNS\"] = 72\n",
    "using Pkg; for p in (\"Knet\",\"IterTools\"); haskey(Pkg.installed(),p) || Pkg.add(p); end\n",
    "using Statistics: mean\n",
    "using IterTools: ncycle\n",
    "using Knet: Knet, AutoGrad, RNN, param, dropout, minibatch, nll, accuracy, progress!, adam, save, load, gc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0e-8"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Set constants for the model and training\n",
    "EPOCHS=3          # Number of training epochs\n",
    "BATCHSIZE=64      # Number of instances in a minibatch\n",
    "EMBEDSIZE=125     # Word embedding size\n",
    "NUMHIDDEN=100     # Hidden layer size\n",
    "MAXLEN=150        # Maximum length of a word sequence: pad shorter sequences, truncate longer ones\n",
    "VOCABSIZE=30000   # Maximum vocabulary size: keep the most frequent 30K, map the rest to the UNK token\n",
    "NUMCLASS=2        # Number of output classes\n",
    "DROPOUT=0.5       # Dropout rate\n",
    "LR=0.001          # Learning rate\n",
    "BETA_1=0.9        # Adam optimization parameter\n",
    "BETA_2=0.999      # Adam optimization parameter\n",
    "EPS=1e-08         # Adam optimization parameter"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load and view data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "imdb"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "include(Knet.dir(\"data\",\"imdb.jl\"))   # defines imdb loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "slideshow": {
     "slide_type": "skip"
    }
   },
   "outputs": [
    {
     "data": {
      "text/latex": [
       "\\begin{verbatim}\n",
       "imdb()\n",
       "\\end{verbatim}\n",
       "Load the IMDB Movie reviews sentiment classification dataset from https://keras.io/datasets and return (xtrn,ytrn,xtst,ytst,dict) tuple.\n",
       "\n",
       "\\section{Keyword Arguments:}\n",
       "\\begin{itemize}\n",
       "\\item url=https://s3.amazonaws.com/text-datasets: where to download the data (imdb.npz) from.\n",
       "\n",
       "\n",
       "\\item dir=Pkg.dir(\"Knet/data\"): where to cache the data.\n",
       "\n",
       "\n",
       "\\item maxval=nothing: max number of token values to include. Words are ranked by how often they occur (in the training set) and only the most frequent words are kept. nothing means keep all, equivalent to maxval = vocabSize + pad + stoken.\n",
       "\n",
       "\n",
       "\\item maxlen=nothing: truncate sequences after this length. nothing means do not truncate.\n",
       "\n",
       "\n",
       "\\item seed=0: random seed for sample shuffling. Use system seed if 0.\n",
       "\n",
       "\n",
       "\\item pad=true: whether to pad short sequences (padding is done at the beginning of sequences). pad\\_token = maxval.\n",
       "\n",
       "\n",
       "\\item stoken=true: whether to add a start token to the beginning of each sequence. start\\_token = maxval - pad.\n",
       "\n",
       "\n",
       "\\item oov=true: whether to replace words >= oov\\emph{token with oov}token (the alternative is to skip them). oov\\_token = maxval - pad - stoken.\n",
       "\n",
       "\\end{itemize}\n"
      ],
      "text/markdown": [
       "```\n",
       "imdb()\n",
       "```\n",
       "\n",
       "Load the IMDB Movie reviews sentiment classification dataset from https://keras.io/datasets and return (xtrn,ytrn,xtst,ytst,dict) tuple.\n",
       "\n",
       "# Keyword Arguments:\n",
       "\n",
       "  * url=https://s3.amazonaws.com/text-datasets: where to download the data (imdb.npz) from.\n",
       "  * dir=Pkg.dir(\"Knet/data\"): where to cache the data.\n",
       "  * maxval=nothing: max number of token values to include. Words are ranked by how often they occur (in the training set) and only the most frequent words are kept. nothing means keep all, equivalent to maxval = vocabSize + pad + stoken.\n",
       "  * maxlen=nothing: truncate sequences after this length. nothing means do not truncate.\n",
       "  * seed=0: random seed for sample shuffling. Use system seed if 0.\n",
       "  * pad=true: whether to pad short sequences (padding is done at the beginning of sequences). pad_token = maxval.\n",
       "  * stoken=true: whether to add a start token to the beginning of each sequence. start_token = maxval - pad.\n",
       "  * oov=true: whether to replace words >= oov*token with oov*token (the alternative is to skip them). oov_token = maxval - pad - stoken.\n"
      ],
      "text/plain": [
       "\u001b[36m  imdb()\u001b[39m\n",
       "\n",
       "  Load the IMDB Movie reviews sentiment classification dataset from\n",
       "  https://keras.io/datasets and return (xtrn,ytrn,xtst,ytst,dict)\n",
       "  tuple.\n",
       "\n",
       "\u001b[1m  Keyword Arguments:\u001b[22m\n",
       "\u001b[1m  ≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡\u001b[22m\n",
       "\n",
       "    •    url=https://s3.amazonaws.com/text-datasets: where to\n",
       "        download the data (imdb.npz) from.\n",
       "\n",
       "    •    dir=Pkg.dir(\"Knet/data\"): where to cache the data.\n",
       "\n",
       "    •    maxval=nothing: max number of token values to include.\n",
       "        Words are ranked by how often they occur (in the training\n",
       "        set) and only the most frequent words are kept. nothing\n",
       "        means keep all, equivalent to maxval = vocabSize + pad +\n",
       "        stoken.\n",
       "\n",
       "    •    maxlen=nothing: truncate sequences after this length.\n",
       "        nothing means do not truncate.\n",
       "\n",
       "    •    seed=0: random seed for sample shuffling. Use system seed\n",
       "        if 0.\n",
       "\n",
       "    •    pad=true: whether to pad short sequences (padding is done\n",
       "        at the beginning of sequences). pad_token = maxval.\n",
       "\n",
       "    •    stoken=true: whether to add a start token to the beginning\n",
       "        of each sequence. start_token = maxval - pad.\n",
       "\n",
       "    •    oov=true: whether to replace words >= oov\u001b[4mtoken with\n",
       "        oov\u001b[24mtoken (the alternative is to skip them). oov_token =\n",
       "        maxval - pad - stoken."
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@doc imdb"
   ]
  },
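  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The docstring above gives formulas for the special token ids. As a quick sanity check, the sketch below evaluates them with plain integer arithmetic (`true` counts as 1 in Julia), assuming `maxval=30000` (the `VOCABSIZE` set earlier) and the defaults `pad=true` and `stoken=true`. The name `special_tokens` is only illustrative and is not part of the loader."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the docstring formulas for the assumed settings\n",
    "special_tokens = let maxval = 30000, pad = true, stoken = true\n",
    "    (pad_token = maxval, start_token = maxval - pad, oov_token = maxval - pad - stoken)\n",
    "end"
   ]
  },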
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "┌ Info: Loading IMDB...\n",
      "└ @ Main /kuacc/users/dyuret/.julia/dev/Knet/data/imdb.jl:57\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  6.578213 seconds (26.04 M allocations: 1.354 GiB, 9.87% gc time)\n"
     ]
    }
   ],
   "source": [
    "@time (xtrn,ytrn,xtst,ytst,imdbdict)=imdb(maxlen=MAXLEN,maxval=VOCABSIZE);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "25000-element Array{Array{Int32,1},1}\n",
      "25000-element Array{Int8,1}\n",
      "25000-element Array{Array{Int32,1},1}\n",
      "25000-element Array{Int8,1}\n",
      "Dict{String,Int32} with 88584 entries\n"
     ]
    }
   ],
   "source": [
    "println.(summary.((xtrn,ytrn,xtst,ytst,imdbdict)));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1×150 LinearAlgebra.Adjoint{Int32,Array{Int32,1}}:\n",
       " 30000  30000  30000  30000  …  815  9  39  842  40  89  714  9"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Words are encoded with integers\n",
    "rand(xtrn)'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1×25000 LinearAlgebra.Adjoint{Int64,Array{Int64,1}}:\n",
       " 150  150  150  150  150  150  150  …  150  150  150  150  150  150"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Each word sequence is padded or truncated to length 150\n",
    "length.(xtrn)'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "reviewstring (generic function with 2 methods)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Define a function that can print the actual words:\n",
    "imdbvocab = Array{String}(undef,length(imdbdict))\n",
    "for (k,v) in imdbdict; imdbvocab[v]=k; end\n",
    "imdbvocab[VOCABSIZE-2:VOCABSIZE] = [\"<unk>\",\"<s>\",\"<pad>\"]\n",
    "function reviewstring(x,y=0)\n",
    "    x = x[x.!=VOCABSIZE] # remove pads\n",
    "    \"\"\"$((\"Sample\",\"Negative\",\"Positive\")[y+1]) review:\\n$(join(imdbvocab[x],\" \"))\"\"\"\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Positive review:\n",
      "in the film is sung in fact it contains no more songs than a regular musical it is actually a lot more like a chaplin or buster keaton or marx brothers film my criticisms of the disc are not that important heck criterion has the right to smack me around for making those complaints the fact is their people probably spent hundreds of hours fixing up a film which only 20 now 21 people have voted for on imdb and only about a hundred people if that will ever see the film heck if you look at the criterion web site le million is nowhere to be found i have no clue why not it's something they should really be proud of of course their web site is surprisingly horrible they did a fine job on this film bravo they deserve all the money i can stand to give them\n"
     ]
    }
   ],
   "source": [
    "# Hit Ctrl-Enter to see random reviews:\n",
    "r = rand(1:length(xtrn))\n",
    "println(reviewstring(xtrn[r],ytrn[r]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1×25000 LinearAlgebra.Adjoint{Int8,Array{Int8,1}}:\n",
       " 2  1  2  1  1  1  2  2  1  1  2  …  1  1  1  2  2  2  2  1  2  1  1"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Here are the labels: 1=negative, 2=positive\n",
    "ytrn'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "source": [
    "## Define the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "struct SequenceClassifier; input; rnn; output; pdrop; end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SequenceClassifier"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "SequenceClassifier(input::Int, embed::Int, hidden::Int, output::Int; pdrop=0) =\n",
    "    SequenceClassifier(param(embed,input), RNN(embed,hidden,rnnType=:gru), param(output,hidden), pdrop)"
   ]
  },
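  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "function (sc::SequenceClassifier)(input)\n",
    "    # input: B sequences of T word ids each; hcat gives a T×B matrix, permutedims makes it B×T\n",
    "    embed = sc.input[:, permutedims(hcat(input...))]  # embeddings of size (EMBEDSIZE,B,T)\n",
    "    embed = dropout(embed,sc.pdrop)\n",
    "    hidden = sc.rnn(embed)                             # hidden states of size (NUMHIDDEN,B,T)\n",
    "    hidden = dropout(hidden,sc.pdrop)\n",
    "    return sc.output * hidden[:,:,end]                 # class scores from the last time step\n",
    "end\n",
    "\n",
    "# Loss: negative log likelihood of the correct labels\n",
    "(sc::SequenceClassifier)(input,output) = nll(sc(input),output)"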
   ]
  },
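  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the array shapes in the forward pass concrete, here is a toy check of the embedding lookup that uses plain arrays instead of Knet parameters. The sizes `E`, `V`, `B`, `T` and the name `lookup_shapes` are made up for illustration; the real model uses `EMBEDSIZE`, `VOCABSIZE`, `BATCHSIZE` and `MAXLEN`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy shape check (illustrative sizes, no Knet needed)\n",
    "lookup_shapes = let E = 4, V = 10, B = 3, T = 5\n",
    "    W = randn(E, V)                      # embedding matrix, one column per word id\n",
    "    batch = [rand(1:V, T) for _ in 1:B]  # B sequences of T word ids\n",
    "    ids = permutedims(hcat(batch...))    # B×T matrix of word ids\n",
    "    size(ids), size(W[:, ids])           # the lookup adds the embedding dimension in front\n",
    "end"
   ]
  },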
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(390, 390)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dtrn = minibatch(xtrn,ytrn,BATCHSIZE;shuffle=true)\n",
    "dtst = minibatch(xtst,ytst,BATCHSIZE)\n",
    "length.((dtrn,dtst))"
   ]
  },
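  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The batch counts above follow from integer division: 25000 instances do not divide evenly by 64, and the reported 390 suggests that the final partial batch is dropped. The sketch below redoes the arithmetic with literal values (the name `batch_counts` is illustrative) and also gives the total number of updates for 3 epochs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batches per epoch, leftover instances, and total updates over all epochs\n",
    "batch_counts = let n = 25000, batchsize = 64, epochs = 3\n",
    "    batches = div(n, batchsize)    # full minibatches per epoch\n",
    "    leftover = rem(n, batchsize)   # instances that do not fill a final minibatch\n",
    "    (batches, leftover, epochs * batches)\n",
    "end"
   ]
  },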
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "trainresults (generic function with 1 method)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# For running experiments\n",
    "function trainresults(file,maker; o...)\n",
    "    if (print(\"Train from scratch? \"); startswith(readline(),\"y\"))  # empty input counts as no\n",
    "        model = maker()\n",
    "        progress!(adam(model,ncycle(dtrn,EPOCHS);lr=LR,beta1=BETA_1,beta2=BETA_2,eps=EPS))\n",
    "        Knet.save(file,\"model\",model)\n",
    "        Knet.gc() # To save gpu memory\n",
    "    else\n",
    "        isfile(file) || download(\"http://people.csail.mit.edu/deniz/models/tutorial/$file\",file)\n",
    "        model = Knet.load(file,\"model\")\n",
    "    end\n",
    "    return model\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "maker (generic function with 1 method)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "maker() = SequenceClassifier(VOCABSIZE,EMBEDSIZE,NUMHIDDEN,NUMCLASS,pdrop=DROPOUT)\n",
    "# model = maker()\n",
    "# nll(model,dtrn), nll(model,dtst), accuracy(model,dtrn), accuracy(model,dtst)\n",
    "# (0.69312066f0, 0.69312423f0, 0.5135817307692307, 0.5096153846153846)"
   ]
  },
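  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The commented-out numbers above show an untrained model with a loss near 0.693 and accuracy near 50%. That loss is what a two-class classifier gets when it assigns equal probability to both classes, as this small check with plain arrays illustrates (the name `uniform_nll` is illustrative)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Negative log likelihood of a uniform guess over two classes\n",
    "uniform_nll = let scores = zeros(2)               # equal scores for both classes\n",
    "    probs = exp.(scores) ./ sum(exp.(scores))      # softmax gives [0.5, 0.5]\n",
    "    -log(probs[1])                                 # equals log(2), about 0.6931\n",
    "end"
   ]
  },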
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train from scratch? stdin> n\n"
     ]
    }
   ],
   "source": [
    "model = trainresults(\"imdbmodel132.jld2\",maker);\n",
    "# ┣████████████████████┫ [100.00%, 1170/1170, 00:15/00:15, 76.09i/s]\n",
    "# nll(model,dtrn), nll(model,dtst), accuracy(model,dtrn), accuracy(model,dtst)\n",
    "# (0.05217469f0, 0.3827392f0, 0.9865785256410257, 0.8576121794871795)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Playground"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "str2ids (generic function with 1 method)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
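   ],
   "source": [
    "# Predict the label of a single sequence of word ids with the trained model\n",
    "predictstring(x)=\"\\nPrediction: \" * (\"Negative\",\"Positive\")[argmax(Array(vec(model([x]))))]\n",
    "UNK = VOCABSIZE-2  # id of the <unk> token\n",
    "# Convert a string to word ids; words missing from the dictionary or vocabulary become UNK\n",
    "str2ids(s::String)=[(i=get(imdbdict,w,UNK); i>=UNK ? UNK : i) for w in split(lowercase(s))]"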
   ]
  },
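  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The id conversion can be tried without loading the dataset. The sketch below applies the same lookup rule to a made-up dictionary: `toydict`, its ids and the name `toy_ids` are hypothetical, and `unk` is set to 29998 to mirror `VOCABSIZE-2`. A word whose id is at or above `unk` and a word missing from the dictionary both map to `unk`. Since the text is split on whitespace only, punctuation attached to a word would also make it unknown."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same lookup rule as str2ids, applied to a hypothetical toy dictionary\n",
    "toy_ids = let toydict = Dict(\"good\" => 5, \"movie\" => 7, \"zyzzyva\" => 40000), unk = 29998\n",
    "    [(i = get(toydict, w, unk); i >= unk ? unk : i) for w in split(lowercase(\"Good movie zyzzyva blah\"))]\n",
    "end"
   ]
  },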
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Negative review:\n",
      "<s> full disclosure i was born in 1967 at first the premise tickled me after all if you were a teenager growing up in the age of reagan a trip down memory lane was worth a laugh or three pop rocks atari and rock 'em sock 'em robots loved 'em not despite the fact they were goofy but because they were goofy and silly and fun but when vh1 decided to make the series i love last tuesday i knew enough was enough goes hal sparks have nothing better to do than read from a <unk> how idiotic are wait let me check hmmm no no it appears he doesn't snarky snotty uber hip posturing has its place but enough already\n",
      "\n",
      "Prediction: Negative\n"
     ]
    }
   ],
   "source": [
    "# Here we can see predictions for random reviews from the test set; hit Ctrl-Enter to sample:\n",
    "r = rand(1:length(xtst))\n",
    "println(reviewstring(xtst[r],ytst[r]))\n",
    "println(predictstring(xtst[r]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "stdin> i love this movie\n",
      "\n",
      "Prediction: Positive\n"
     ]
    }
   ],
   "source": [
    "# Type your own review at the prompt to classify it:\n",
    "println(predictstring(str2ids(readline(stdin))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "@webio": {
   "lastCommId": null,
   "lastKernelId": null
  },
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "julia.ipynb",
   "provenance": [],
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Julia 1.3.1",
   "language": "julia",
   "name": "julia-1.3"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.3.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
